from jpype import JArray, JInt
from prompta.core.alphabet.verbs import VERBS
from prompta.utils.java_libs import Integer
from prompta.core.alphabet.events import EVENTS, is_event
from pipelines.prompta.utils import load_dfa, query2str, show_dfa



class Skill:
    """A skill whose behaviour is driven by a DFA loaded from ``automaton_path``.

    The automaton is walked one symbol at a time: ``next_symbol`` proposes an
    input that keeps the run able to reach an accepting state, and ``step``
    feeds a symbol, advances ``current_state`` and reports ``(success, done)``.
    """

    def __init__(self, name, automaton_path):
        """Store the skill's name/path and load its automaton.

        Args:
            name: Skill/program name, stored as ``program_name``.
            automaton_path: Path of the serialized automaton; read as text and
                also passed to ``load_dfa``.
        """
        self.program_name = name
        self.automaton_path = automaton_path

        self.setup_automaton()

    def setup_automaton(self):
        """(Re)load the automaton and precompute per-state distances to acceptance.

        Sets ``automaton_text``, ``automaton``, ``alphabet``, ``states``,
        ``current_state`` (the initial state), ``events``, ``accDists`` and
        ``reject_state_dist``.
        """
        # Use a context manager so the file handle is closed deterministically.
        with open(self.automaton_path) as f:
            self.automaton_text = f.read()
        model_data = load_dfa(self.automaton_path)
        self.automaton = model_data.model

        self.alphabet = model_data.alphabet
        self.states = list(self.automaton.getStates())
        self.current_state = self.automaton.getIntInitialState()

        # Map each alphabet symbol (as str) to a copy of the first EVENTS entry
        # that accepts it; symbols no event accepts are left out.
        self.events = {}
        for a in self.alphabet:
            a = str(a)
            for event in EVENTS:
                if event.accept(a):
                    self.events[a] = event.copy(a)
                    break

        # Sentinel "unreachable" distance. NOTE(review): a real distance can only
        # collide with it if the automaton has >= 65535 states.
        inf = int(2 ** 16 - 1)
        self.accDists = self._compute_acc_dists(inf)
        # States whose distance stays at the sentinel cannot reach acceptance.
        self.reject_state_dist = inf

    def _compute_acc_dists(self, inf):
        """Return, per state index, the fewest transitions to an accepting state.

        Accepting states get 0; states with no path to an accepting state keep
        ``inf``. Undefined transitions (``getTransition`` returning ``None``) are
        skipped, matching how ``next_symbol`` treats them.
        """
        size = self.automaton.size()
        dists = [inf] * size
        for state in range(size):
            if self.automaton.isAccepting(JInt(state)):
                dists[state] = 0

        # Symbol indices do not change during relaxation; look them up once.
        sym_indices = [JInt(self.alphabet.getSymbolIndex(sym)) for sym in self.alphabet]

        # Relax every state until no distance improves (fixed point).
        changed = True
        while changed:
            changed = False
            for state in range(size):
                j_state = JInt(state)
                best = inf
                for sym_idx in sym_indices:
                    trans = self.automaton.getTransition(j_state, sym_idx)
                    if trans is None:
                        continue
                    best = min(best, dists[self.automaton.getIntSuccessor(trans)])
                if best != inf and best + 1 < dists[state]:
                    dists[state] = best + 1
                    changed = True
        return dists

    def reset(self):
        """Return the run to the automaton's initial state."""
        self.current_state = self.automaton.getIntInitialState()

    @property
    def next_symbol(self):
        """Propose the next input symbol from ``current_state``.

        Returns:
            ``(symbol, is_action)`` for the first symbol, in alphabet order, other
            than ``"NOOP"`` whose defined successor can still reach an accepting
            state. ``is_action`` is False when ``is_event(symbol)`` is truthy,
            True otherwise. Returns ``None`` when no such symbol exists.
        """
        j_state = JInt(self.current_state)
        for a in self.alphabet:
            trans = self.automaton.getTransition(j_state, JInt(self.alphabet.getSymbolIndex(a)))
            if trans is None:
                continue
            if self.accDists[self.automaton.getIntSuccessor(trans)] == self.reject_state_dist:
                continue
            symbol = str(a)
            if symbol == "NOOP":
                continue
            return symbol, not is_event(symbol)
        return None

    def step(self, symbol: str):
        """Feed ``symbol`` to the automaton and advance ``current_state``.

        Args:
            symbol: An alphabet symbol.

        Returns:
            ``(success, done)``: ``success`` is whether the new state is
            accepting; ``done`` is True on success or when the new state can no
            longer reach an accepting state.

        Raises:
            ValueError: if ``symbol`` has no defined transition from the current
                state (the state is left unchanged).
        """
        print("step: ", symbol)
        trans = self.automaton.getTransition(JInt(self.current_state), JInt(self.alphabet.getSymbolIndex(symbol)))
        if trans is None:
            raise ValueError(f"no transition for symbol {symbol!r} from state {self.current_state}")
        self.current_state = self.automaton.getIntSuccessor(trans)
        success = self.automaton.isAccepting(self.current_state)
        done = success or self.accDists[self.current_state] == self.reject_state_dist

        return success, done


if __name__ == "__main__":
    # Demo driver: walk the skill automaton for up to 20 steps, printing the
    # proposed actions (and their generated code) and events.
    # An optional first command-line argument overrides the default DFA path.
    import sys

    default_dfa_path = r"E:\llm-automata\PROMPTA\pipelines\prompta\test\mc_ckpt_template\skills\equipaxe.taf"
    dfa_path = sys.argv[1] if len(sys.argv) > 1 else default_dfa_path
    abp = Skill("equipaxe", dfa_path)
    print(abp.states, abp.current_state)
    # Initialised so the final print is valid even if no step is taken.
    success, done = False, False
    for _ in range(20):
        proposal = abp.next_symbol
        if proposal is None:
            # No non-NOOP symbol keeps the run able to reach acceptance.
            break
        next_symbol, is_action = proposal
        if is_action:
            print("action: ", next_symbol)
            for verb in VERBS:
                if verb.accept(next_symbol):
                    print(verb.to_code(next_symbol))
        else:
            print("event: ", next_symbol)
            # Events are not fed directly; the run advances with NOOP instead.
            next_symbol = "NOOP"
        success, done = abp.step(next_symbol)
        if done:
            break
    print(success, done)
